[feat] HSTUMatch: scalar item-tower export view#518
Conversation
…nce prefix HSTUMatch is the only model whose candidate side is a grouped sequence sub-feature; PR alibaba#506 papered over this by writing `attr_fields: "cand_seq__video_id"` in the test config and docs -- the *flattened parquet column name*. Every other sampler config uses the bare sub-feature name (`attr_fields: "item_id"` in DSSM/MIND/ TDM) because their candidate is a top-level feature where bare-name == flattened-name. The HSTUMatch outlier leaks DataParser's `{sequence_name}__{sub_feature}` flattening convention into user-facing config. Resolve at the dataset boundary: when `item_id_field` carries a qualified `{sequence_name}__{sub_feature}` form, derive the prefix and rewrite any bare `attr_fields` entries to the flattened form before constructing the sampler. Deep-copy the sampler sub-message so the original `data_config` is not mutated. The sampler then sees fully-qualified names just like today; `_valid_attr_names`, `_attr_types`, the sampled output dict, and `_merge_sampled_features` all work unchanged. Literal matches win, so already-qualified configs continue to work, and DSSM/MIND/TDM (where `item_id_field` is bare) skip the resolution branch entirely. Test config + doc switch to: attr_fields: "video_id" (bare sub-feature name) item_id_field: "cand_seq__video_id" (qualified; doubles as the sequence_name source) Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…r export view Add a module-level helper that projects a grouped sequence sub-feature (e.g. `cand_seq__video_id`) into a top-level scalar `FeatureConfig` suitable for `create_features()` to construct as a non-sequence feature named `video_id`. The grouped sub-feature carries `SeqFeatureConfig`; the helper rewraps the contained oneof message (id_feature / raw_feature / combo_feature / ... — generic across the entire oneof) as a fresh `FeatureConfig`. It also materializes the source feature's effective `default_value` and `value_dim` onto the scalar proto: a sequence sub-feature with no explicit `default_value` resolves to `"0"` and `value_dim` to `1` (feature.py:556, 515-517), but a scalar with no explicit values would fall back to `""` / `0`. Without materialization the exported scalar feature would drift from the training semantics. Used by the HSTUMatchItemTower scalar export view (next commit) to swap from `cand_seq__video_id` (jagged) to `video_id` (scalar) without mutating the training feature objects. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Make the item tower own its training/export view switch:
- `__init__` introduces `self._cand_key = "{group_name}.sequence"`
(training view). `forward()` reads `grouped_features[self._cand_key]`
instead of the hard-coded `.sequence`.
- `set_is_inference(True)` projects each grouped sequence sub-feature
into a top-level scalar `FeatureConfig` via the helper from the
previous commit, rebuilds `self._features` as scalar features, swaps
`self._feature_groups` to a single `JAGGED_SEQUENCE` group over the
scalar names (so `EmbeddingGroup` emits `{group_name}.query` per row
instead of jagged-per-row), and flips `self._cand_key` to
`.query`. Idempotent: re-entering after the swap is a no-op.
`set_is_inference(False)` does not rebuild the training view --
callers must export from a `copy.deepcopy` of the item tower so the
training tower is preserved.
In `tzrec/utils/export_util.py::export_model_normal`, save the
wrapper's `model._feature_groups` onto the saved
`pipeline.config.model_config.feature_groups`. Today the saved config
keeps the original (training-shape) feature_groups even when the
exported `feature_configs` have been rewritten -- after the
HSTUMatchItemTower view swap, this mismatch shows up as scalar
`feature_configs` paired with stale `cand_seq__video_id` group names.
Non-match towers expose the same `_feature_groups` they had before, so
DSSM/MIND/etc. behaviour is unchanged.
`TowerWoEGWrapper` (match_model.py:481-486) already reads
`module._features` / `module._feature_groups` -- no wrapper changes
required.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Wires the HSTUMatchItemTower scalar export view (`candidate.query` per
row over scalar `video_id` instead of the training `candidate.sequence`
over jagged `cand_seq__video_id`) into the export pipeline. End-to-end:
* `tzrec/protos/pipeline.proto`: add `optional string item_input_path
= 10` on `EasyRecConfig`. Item-only table (one row per item, scalar
export-view schema) for recall-model item-tower export.
* `tzrec/utils/export_util.py::export_model_normal` (and the
`_aot` variant): when `model._tower_name == "item_tower"` and
`pipeline_config.item_input_path` is set, point the predict
dataloader at it instead of `train_input_path`, and clear
`data_config.sampler` so the predict path doesn't launch GraphLearn
for item-only rows. Also save `model._feature_groups` onto
`pipeline_config.model_config.feature_groups` so the exported config
is internally consistent (scalar feature_configs paired with scalar
feature_groups, not stale training groups).
* `tzrec/models/model.py::ScriptWrapper.__init__`: propagate
`_tower_name` from the inner module so `export_util.py` can see it
through the wrapper.
* `tzrec/main.py::export`: for match-model item towers, build the
wrapper around `copy.deepcopy(module).set_is_inference(True)` so the
view flip happens BEFORE `TowerWoEGWrapper` constructs its
`EmbeddingGroup`. The original training tower is untouched.
* `tzrec/models/hstu.py::HSTUMatchItemTower.set_is_inference`: drop
the `super().set_is_inference()` call -- `MatchTowerWoEG` derives
from `nn.Module`, not `BaseModule`. The `_is_inference` attribute is
propagated to sub-modules separately by `ScriptWrapper.set_is_inference`
via `recursive_setattr` during export.
* `tzrec/features/feature.py::project_grouped_sequence_feature_to_scalar`:
drop the materialization of sequence-effective `default_value` /
`value_dim` onto the scalar proto. The scalar-mode defaults
(empty / 0) are intentional: at item-export predict time the parquet
provides one value per row, and PyFG's `id_feature` operator
expects the scalar (not sequence) input shape.
* `tzrec/tests/configs/hstu_kuairand_1k.config`: set
`item_input_path: "data/test/kuairand-1k-match-item-c1.parquet"`.
* `tzrec/tests/match_integration_test.py::test_hstu_with_fg_train_eval`:
extend the body to also exercise `test_export` (AOT export with
`ENABLE_AOT=1` + `DISABLE_MMA_V3=1` -- same pattern as dlrm_hstu),
`test_predict` on the item tower (reading the new
`kuairand-1k-match-item-c1.parquet`, emitting `item_tower_emb`), and
`test_predict` on the user tower (reading the eval parquet for
the sequence-shape user-side path). Assert both
`export/{user,item}/scripted_sparse_model.pt` exist.
The new test fixture `data/test/kuairand-1k-match-item-c1.parquet`
(1000 rows, single `video_id: int64` column; md5
8dcadabdc3e9049ed9c2250565b4b134) is built locally by
`experiments/hstu_match/build_kuairand_fixtures.py`; the `ci_data.sh`
wget line will land after the user uploads it to OSS.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…wer scalar export test Follows the previous commit (HSTUMatch item-tower scalar export view + item_input_path routing). 1000-row scalar item parquet; consumed by `test_hstu_with_fg_train_eval` for the item-tower predict step. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…eq_field_delims
Replace the global `_seq_field_delims: Dict[str, str]` (every sequence
input -> delim) with two candidate-side fields derived from the
matching feature config:
_sampler_seq_delim -- "" if item_id_field isn't a sequence input;
the parent feature's sequence_delim otherwise
(covers top-level sequence_id_feature AND
grouped sequence sub-features).
_sampler_seq_prefix -- "" if item_id_field isn't grouped-flattened;
f"{sequence_name}{_underline}" otherwise
(RTP-safe; uses feature._underline directly
rather than guessing "__" vs "_").
Lookup is by `item_id_field in feature.sequence_input_names`, which
returns `[feature.name]` for top-level sequence features and the
flattened input names for grouped sub-features -- correctly matching
both cases the old `_seq_field_delims` dict covered.
Three downstream simplifications:
* `launch_sampler_cluster`: collapse the two checks into one
`if self._sampler_seq_delim:` gate. The bare-name prefix-resolve
runs only when `_sampler_seq_prefix` is set (grouped case). The
outer-list strip is scoped to candidate-sequence attrs
(`{a for a in attr_fields if a.startswith(seq_prefix)}`) -- top-level
sequence (prefix="") matches all attr_fields; grouped sequence picks
out just the seq-prefixed subset; non-sequence item-side attrs from
the same lookup feature (e.g. `cat_map`) are correctly excluded.
* `utils.build_sampler_input`: signature change from
`seq_field_delims: Dict[str, str]` to `seq_delim: str` -- callers
pass `self._sampler_seq_delim` (empty when not applicable).
* `_merge_sampled_features`: drop the per-key dict lookup; single
`self._sampler_seq_delim` applies uniformly because
`sampled.keys()` is always a subset of `attr_fields` (candidate-side
only) after the prefix-resolve step.
DSSM/MIND/TDM behaviour unchanged: when item_id_field is a top-level
scalar, `_sampler_seq_delim` stays empty and every branch degrades to
the pre-refactor scalar path.
Tests:
* dataset_test.py: replace `_seq_field_delims` membership assertions
with `_sampler_seq_delim` / `_sampler_seq_prefix` checks on the
same lookup_feature multi-attr-strip case.
* utils_test.py: update `build_sampler_input` test kwargs and the
empty-delim-passthrough case.
Doc cleanup: trim the trailing "dataset 层会..." sentence in
`hstu_match.md` (internal mechanics no longer accurate) and reword
"带 sequence_name 的全限定名" to "带 sequence_name 的序列前缀的名".
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…eq defaults; hoist set_is_inference
Three review-prompted fixes:
1. project_grouped_sequence_feature_to_scalar() now materializes the
source feature's effective default_value and value_dim onto the
scalar proto. Without this, an id_feature projected from a sequence
sub-feature (sequence-effective defaults "0" / 1) silently flipped
to scalar defaults ("" / 0) -- behaviorally divergent from the
training sub-feature it was projected from. The earlier removal
chased a misdiagnosed FG failure that was actually a _tower_name
propagation bug (fixed in commit 0c1dcf1).
2. HSTUMatchItemTower exposes `features` and `feature_groups` as lazy
properties driven by `_is_inference`, instead of mutating
`_features` / `_feature_groups` / `_cand_key` inside
`set_is_inference(True)`. The scalar view is built once on first
property access and cached (`_features_scalar` /
`_feature_groups_scalar`); the training view stays immutable on
the parent's `self._features`, so `set_is_inference(False)`
reverts cleanly. `forward()` derives the candidate key inline
(`{group}.sequence` vs `{group}.query`) from the flag -- no
cached `_cand_key` attribute.
`MatchTowerWoEG` exposes default `features` / `feature_groups`
properties (forwarding to the underscore fields). Wrappers
(`TowerWrapper`, `TowerWoEGWrapper`, `ScriptWrapper`) read via the
properties and expose their own snapshot properties. Non-HSTUMatch
towers are unaffected -- the default properties match the
pre-refactor direct attribute reads.
3. set_is_inference(True) is hoisted to before every InferWrapper()
in tzrec/main.py::export, removed from the three post-wrap call
sites in tzrec/utils/export_util.py (lines 201, 743, 1065). The
recursive_setattr from BaseModule.set_is_inference propagates the
flag down to all sub-modules including the inner towers; wrappers
are then constructed with the inference-mode view already
established (so TowerWoEGWrapper's EmbeddingGroup is built off
HSTUMatchItemTower's scalar features). Drop the per-item-tower
deep-copy in main.py -- mutation is gone, the flag is the only
state change.
The HSTUMatchTest JIT_SCRIPT branch no longer flips `_is_inference`
(it was the source of a view-toggle conflict: the model's own
EmbeddingGroup emits `.sequence`-shaped grouped_features but
`_is_inference=True` would make the item tower read `.query`,
KeyError). JIT compiles both branches; the training-shape batch
runs the training branch.
The feature projection unit test reverts to assert materialization:
`scalar_cfg.id_feature.default_value == "0"`, `value_dim == 1`.
Non-functional for DSSM/MIND/TDM: properties default to underscore
fields, set_is_inference flow is unchanged for non-HSTUMatch towers.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…rop hasattr fallback in ScriptWrapper Three follow-on simplifications on top of commit 8cd5d3e: 1. `HSTUMatchItemTower.set_is_inference` override removed -- now dead code. `BaseModule.set_is_inference` (called once on the full model in `main.py::export`) uses `recursive_setattr` which sets the `_is_inference` attribute on every sub-module directly; it never calls sub-modules' methods. With the lazy-property design, the flag flip on `HSTUMatchItemTower._is_inference` IS the toggle -- no extra method needed. 2. `TowerWrapper` / `TowerWoEGWrapper` now read `features` / `feature_groups` lazily via `getattr(self, self._tower_name)` -- no construction-time snapshot of those metadata fields. The EmbeddingGroup (which owns nn.Parameters) still snapshots at construction; the metadata properties stay live so they reflect whatever view the inner tower currently exposes. 3. `BaseModel` and `TDMEmbedding` get explicit `features` / `feature_groups` properties (default reads of `_features` / `_feature_groups`). With every wrapped model exposing the property, `ScriptWrapper` drops the `_features_from` / `_feature_groups_from` hasattr-fallback helpers and just reads `self.model.features` / `self.model.feature_groups`. The wrapper also no longer snapshots into `self._features` / `self._feature_groups` -- those properties forward live to `self.model.features` / `self.model.feature_groups`. `tzrec/utils/export_util.py` and `tzrec/acc/aot_utils.py` migrated from `model._features` / `model._feature_groups` to the property API (8 occurrences total). Non-functional for DSSM/MIND/TDM: the new default properties on BaseModel + TDMEmbedding just return the underscore fields, identical to today's direct attribute reads. The HSTUMatch integration test passes end-to-end (~289-313s on local A10). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Three-way disambiguation: literal match in field_names wins, so
already-qualified attrs ("cand_seq__video_id"), top-level item-side
attrs in mixed configs (lookup_feature's "cat_map" whose
sequence_fields exclude it), and bare candidate sub-features
("video_id" -> "cand_seq__video_id") all resolve correctly.
Verified by `tzrec.datasets.dataset_test::test_launch_sampler_cluster_multi_attr_strip_decision_matrix`
which exercises the mixed-attrs case (attr_fields=["cat_map", "click_seq__cat_key"]);
unconditional prefix would break "cat_map" by sending it through as
"click_seq__cat_map" (not in parquet schema).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
… (feature-config only)
Replace the parquet-schema membership check (`prefix + a in field_names`)
with a feature-config-derived check (`prefix + a in self._sampler_seq_inputs`).
`_sampler_seq_inputs` is the union of `sequence_input_names` across all
candidate-side grouped sequence sub-features (those sharing
`sequence_name` with `item_id_field`'s parent), precomputed at
`BaseDataset.__init__` time. Authoritative source: feature configs.
Three semantic consequences vs the old `field_names` check:
- More precise: an attr is prefixed only when `prefix + a` is a known
candidate-sequence input, not just any parquet column. Spurious
same-name parquet columns can't trick the resolution.
- Same outcome for the existing test cases: HSTUMatch ("video_id" ->
"cand_seq__video_id" because "cand_seq__video_id" is a candidate
sequence input); lookup_feature mixed config (`cat_map` stays as-is
because "click_seq__cat_map" is NOT a sequence input; `click_seq__cat_key`
stays as-is because double-prefixing lands outside the set).
- No parquet-schema dependency: the resolution doesn't need to enumerate
`sampler_fields` per launch_sampler_cluster call.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…a_input_path CLI flag `item_input_path` was the wrong abstraction in two ways: 1. It lived in `pipeline.proto` as a model-config field, but it isn't a model-shape concern -- it's a per-invocation override of the export dataloader's input path. Storing it in the proto forces every recall pipeline to ship a config edit just to point export at a different table. 2. The name embedded a hard policy (`item_*`) into a layer (`export_model`) that doesn't care about towers. This commit: - Removes `optional string item_input_path` from `pipeline.proto`. - Adds `--data_input_path` CLI flag on `tzrec/export.py` and threads it through `main.py::export` -> `export_util.export_model` -> `export_model_normal` / `export_rtp_model`. - Renames `item_input_path` -> `data_input_path` in `export_model`. The semantics are now generic: if `data_input_path` is non-empty it overrides `train_input_path` for the predict-mode dataloader; otherwise fall back to the pipeline config. No tower-coupling at this layer. - Drops the `is_item_tower` gate inside `export_util.py` -- predict-mode bypasses the sampler anyway, so the previous `data_config.ClearField( "sampler")` branch was dead too. - Policy "only the item tower receives the override" lives in `main.py::export`'s match-tower loop, where it belongs. - Integration test (`MatchIntegrationTest.test_hstu_with_fg_train_eval`) passes the override via `utils.test_export(..., data_input_path=...)`, matching how real callers will use the CLI flag. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
| self._feature_groups = self.model._feature_groups | ||
| # Propagate tower identity (set by TowerWoEGWrapper / TowerWrapper) | ||
| # so export_util.py can route item-tower export through | ||
| # `pipeline_config.item_input_path` instead of `train_input_path`. |
There was a problem hiding this comment.
Stale reference: item_input_path was moved out of pipeline_config in the precursor commit (1e885b53) and is now the --data_input_path CLI flag. Suggest:
| # `pipeline_config.item_input_path` instead of `train_input_path`. | |
| # so export_util.py can route item-tower export through the | |
| # `--data_input_path` CLI override instead of `train_input_path`. |
| # The EmbeddingGroup itself owns nn.Parameters so it must be a | ||
| # construction-time snapshot; the `features`/`feature_groups` | ||
| # properties below stay live via lazy reads on the inner tower. | ||
| self.embedding_group = EmbeddingGroup(module.features, module.feature_groups) |
There was a problem hiding this comment.
This EmbeddingGroup snapshots the current view at construction time, which only works because the caller in main.py::export already ran model.set_is_inference(True) before wrapping. That invariant is undocumented at this call site and silent if violated (EG built off training jagged features → predict path uses .query key → KeyError at inference). Consider an explicit assertion to lock the contract:
if hasattr(module, "_is_inference"):
assert module._is_inference, (
"TowerWoEGWrapper requires set_is_inference(True) to have run on "
"the inner module before wrapping; otherwise EmbeddingGroup is "
"built off the wrong view."
)
Review summarySolid refactor with thoughtful comments and integration coverage. A few noteworthy points (left as inline comments): Likely bug
Doc
Tests / contracts
No concerns on perf (the hot-path lookup shrinks from dict-keyed to a single attribute load), the new |
Commit 35b18cb made `TowerWrapper` / `ScriptWrapper` read view-state via `@property` (`getattr(self, self._tower_name).features`), and added the property to `MatchTowerWoEG`, `BaseModel`, `TDMEmbedding`. But `MatchTower` (the parent of DSSMTower / DATTower / MIND*Tower) was missed: it sets the underscore fields in `__init__` but never exposes the no-underscore property. `ScriptWrapper.__init__` then crashes during the match-tower wrap with `AttributeError: features` for any non-HSTU match model. Add the two properties to `MatchTower`, mirroring `MatchTowerWoEG`. No behavioural change for HSTU; restores DSSM / MIND / DAT export. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…lar-item-export # Conflicts: # docs/source/models/hstu_match.md # tzrec/datasets/dataset.py # tzrec/datasets/dataset_test.py # tzrec/datasets/utils.py # tzrec/models/match_model.py # tzrec/models/model.py
…path - Revert dataset.py / dataset_test.py / utils.py to origin/master verbatim (post-merge they were a superset; user requested keeping master as-is). - Revert match_model.py `MatchTower` / `MatchTowerWoEG` / wrapper features/feature_groups property docstrings to master wording. - Simplify verbose PR-added comments across main.py, export_util.py, hstu.py, feature.py to one short line each. - Drop the multi-line "DISABLE_MMA_V3 / ENABLE_AOT" rationale block from the HSTU match integration test header. - CLI flag: rename `--data_input_path` back to `--item_input_path` on `tzrec/export.py`; `main.export()` accepts `item_input_path`. The internal `export_model()` function keeps the generic `data_input_path` parameter name. - Combine the three projection tests (`test_id_feature_projection_materializes_seq_defaults`, `test_projection_passes_through_create_features_as_scalar`, `test_raw_feature_projection_generic_oneof`) into a single `test_projection_materializes_defaults_and_passes_through_create_features` covering id_feature defaults materialization, create_features pass-through, and raw_feature oneof coverage. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…contract - `export_util.py`: drop `hasattr(model, "_feature_groups")` gate at the pipeline.config feature_groups write. Post-refactor, no wrapper sets the underscore field as an instance attribute (replaced with `@property`), so the gate never fired. The exported pipeline.config for the HSTU item tower kept stale training-view feature_groups while the feature_configs were the scalar projection -- predict reading the exported config would see a name mismatch. Write `model.feature_groups` unconditionally; the property cascade guarantees the right view on every model that reaches export. - `hstu_test.py`: migrate `_build_model` / `_build_batch` to the grouped-sequence pattern (`sequence_feature` wrapping sub-features named `video_id` in both `uih_seq` and `cand_seq`, sharing one embedding table via aligned `num_buckets` / `embedding_dim` / `embedding_name`). Inline scalar-view assertions into the existing `test_hstu_match` as the final step: stash `item_tower` before the graph_type branch wraps `hstu`, flip `set_is_inference(True)`, and assert the projected scalar names / non-grouped-sequence flag / feature_groups feature_names + group_name. Locks the lazy view contract HSTUMatchItemTower depends on for export. - `model.py`: drop the now-unused `_tower_name` propagation in `ScriptWrapper.__init__`. Originally load-bearing for an export-util routing path that's been replaced by the `--item_input_path` CLI flag (decided at `main.py::export`). No external reader of `ScriptWrapper._tower_name` remains. - `match_integration_test.py`: drop `hstu_env = "DISABLE_MMA_V3=1"` (no longer required by the underlying Triton path) and the inline set_is_inference note in `export_util.py`. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…ower access - `export_util.py`: remove the `pipeline_config.model_config.ClearField` + `extend(model.feature_groups)` block. Predict reads only `pipeline_config.feature_configs` (see `main.py::predict` at the scripted-model path) -- the scripted model has its EmbeddingGroup baked in, so `model_config.feature_groups` in the exported config is dead metadata. No downstream consumer. - `hstu_test.py`: drop the unnecessary `item_tower = hstu.item_tower` stash; `hstu` itself is preserved across `create_test_model` / `TrainWrapper` wrapping (assigned to `hstu_wrapped`), so the scalar-view assertions can read `hstu.item_tower` inline. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
| # Flip to inference *before* wrapping so view-dependent state | ||
| # (e.g. HSTUMatchItemTower's lazy properties, wrapper EmbeddingGroups) | ||
| # is snapshot from the scalar view. | ||
| model.set_is_inference(True) |
There was a problem hiding this comment.
Worth a short comment here naming the reason this must flip on all ranks pre-wrap: export_rtp_model runs init_process_group() / dist.barrier() / DistributedModelParallel collectives, which require every rank to share the same model topology. If set_is_inference ever alters sharded-module shape (as it now does, via HSTUMatchItemTower.features's scalar projection), a rank-0-only flip would diverge ranks. A one-liner would prevent a future refactor from moving it back inside the rank-0 branch.
Also, asymmetry to watch for: the outer ScriptWrapper is constructed after the flip, so its own _is_inference is False while the inner self.model._is_inference is True. Functionally fine (nothing reads the outer flag), but surprising — same comment can call this out.
| # Lazy caches for the scalar export view (populated on first | ||
| # property access after `set_is_inference(True)`). | ||
| self._features_scalar: Optional[List[BaseFeature]] = None | ||
| self._feature_groups_scalar: Optional[List[model_pb2.FeatureGroupConfig]] = None |
There was a problem hiding this comment.
Worth documenting the cache-validity invariant: set_is_inference(False) flips the flag back but does not clear _features_scalar / _feature_groups_scalar. Reuse on a subsequent flip is correct because _features is immutable post-__init__, but a future maintainer reading this will reasonably wonder whether the cache should be invalidated on revert. One line ("cache valid for tower lifetime; _features is immutable") would prevent a well-intentioned "fix".
| if hasattr(dst_msg, "default_value") and not dst_msg.default_value: | ||
| dst_msg.default_value = feature.default_value | ||
| if hasattr(dst_msg, "value_dim") and not dst_msg.HasField("value_dim"): | ||
| dst_msg.value_dim = feature.value_dim |
There was a problem hiding this comment.
Two things:
- The
hasattrguards are meaningful only forOverlapFeature(itsdefault_valueis commented out in the proto); for the other 11SeqFeatureConfig.featurevariants both fields are always present, so the guard reads as defensive code with no purpose. Worth namingOverlapFeaturein a comment. - Test gap: the "don't overwrite" branches (
if not dst_msg.default_value/not dst_msg.HasField("value_dim")) are not exercised. The test infeature_test.pyonly covers the "both unset → materialize from source" path. Consider adding a case where the inner proto already sets e.g.default_value="-1"orvalue_dim=4and asserting the materialization does not overwrite. Same forvalue_dim > 1onraw_feature— currently onlyvalue_dim == 1(default) is asserted.
| self.assertTrue(hstu.item_tower._is_inference) | ||
| scalar_features = hstu.item_tower.features | ||
| scalar_feature_groups = hstu.item_tower.feature_groups | ||
| self.assertEqual(scalar_features[0].name, "video_id") | ||
| self.assertFalse(scalar_features[0].is_grouped_sequence) | ||
| self.assertEqual(scalar_feature_groups[0].feature_names, ["video_id"]) | ||
| self.assertEqual(scalar_feature_groups[0].group_name, "candidate") |
There was a problem hiding this comment.
The scalar-view contract assertions are good, but the new forward() branch (hstu.py:240-241, the .query suffix path) is never invoked here — only via the AOT integration test in match_integration_test.py. If the wrapper's EmbeddingGroup ever emits the scalar group under a key that doesn't match self._group_name + ".query", this unit test won't catch it.
Cheap addition: build a fake grouped_features={"candidate.query": <tensor>} and call hstu.item_tower(grouped_features), asserting the output shape. Locks the suffix contract at unit-test latency rather than relying on the 30-min integration test to surface a regression.
| # item-tower-only; user tower falls back to `train_input_path`. | ||
| tower_data_input_path = ( | ||
| item_input_path if name == "item_tower" else None | ||
| ) |
There was a problem hiding this comment.
The user-tower fallback to train_input_path is asserted only implicitly by the integration test (which happens to predict the user tower over the eval parquet, not the train parquet). So if a refactor accidentally routed item_input_path to the user tower too, the integration test would still pass.
If you want to lock this in cheaply, a small unit test in main_test.py that monkey-patches export_model and asserts data_input_path per tower (item_tower → item_input_path, user_tower → None) would pin the policy.
Review summaryClean, well-scoped change. The two-view model ( No perf, security, or correctness concerns surfaced. Findings are all low — mostly clarity/documentation and minor test gaps:
Nothing blocking. Nice work threading |
…t_path Document the item-tower scalar export flow: `ENABLE_AOT=1` when using Triton kernel + `--item_input_path` pointing at a one-row-per-item parquet. Mirrors the style of `dlrm_hstu.md`'s 模型导出 section. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Summary
HSTUMatchItemTowergets a scalar export view (candidate.queryper row over scalarvideo_id) so the recall item tower can be exported against a one-row-per-item parquet, decoupled from the training-shape sequence rows intrain_input_path.Sits on top of #506 (HSTUMatch encoder refactor), #519 (features/feature_groups as @Property end-to-end), and #520 (sampler bare-name
attr_fields), all already on master.Two-view model
candidate.sequence(jagged)train_input_path(sequence rows)candidate.query(scalar)--item_input_path(scalar parquet)Key design decisions
Lazy view-swap on
HSTUMatchItemTower.features/feature_groupsare@property-driven by_is_inference: training reads the parent's grouped sub-features; export reads a lazily-built scalar projection cached in_features_scalar/_feature_groups_scalar.set_is_inference(True)is a cheap flag toggle, no structural mutation;set_is_inference(False)reverts cleanly.forward()derives the candidate key inline ({group}.sequencevs{group}.query) from the flag — no_cand_keyattribute.Materialized sequence defaults in
project_grouped_sequence_feature_to_scalar. Anid_featureprojected from sequence mode (effectivedefault_value="0"/value_dim=1) carries those defaults forward to the scalar proto; otherwise the export would silently flip to scalar defaults (""/0), diverging from the training sub-feature.Single
set_is_inferenceflow. Hoisted to one site inmain.py::exportbefore everyInferWrapper(...)call.recursive_setattrpropagates to every sub-module before any wrapper builds itsEmbeddingGroupoff view-dependent features. Removed from the three post-wrap call sites inexport_util.py.Item-tower input as a CLI flag.
--item_input_pathontzrec/export.pyplumbs throughmain.export()→export_model(data_input_path=...)→ predict-mode dataloader.main.py::exportapplies the policy that onlyname == "item_tower"receives the override; the user tower keeps readingtrain_input_path. No proto change —EasyRecConfigis untouched.Test plan
python -m unittest tzrec.features.feature_test.ProjectGroupedSequenceFeatureToScalarTest tzrec.models.hstu_test— green on local A10.test_hstu_matchnow uses the production grouped-sequence pattern (sequence_featurewrappingvideo_idsub-feature, shared embedding acrossuih_seq/cand_seqviaembedding_name) and asserts the scalar-view contract (bare names, non-grouped, scalarfeature_groups.feature_names) as the final step.python -m unittest tzrec.tests.match_integration_test.MatchIntegrationTest.test_hstu_with_fg_train_eval— train + eval + AOT export + item/user predict, green on local A10.python -m unittest tzrec.tests.match_integration_test.MatchIntegrationTest.test_{dssm,dat,mind}_*_train_eval_export— non-HSTU match regression, green.pre-commit run— clean on all touched files.Known limitations / follow-ups
sequence_featuresub-features. Top-levelsequence_id_feature/sequence_raw_featurewould raiseValueErrorfromproject_grouped_sequence_feature_to_scalarif reached; HSTUMatch doesn't use them today.🤖 Generated with Claude Code